In [2]:
-- Use single-precision float tensors as Torch's default tensor type for the
-- rest of the notebook session.
torch.setdefaulttensortype('torch.FloatTensor')
Out[2]:
In [3]:
-- Load the serialized training and test sets from the working directory.
-- Later cells read their `data` and `label` fields.
trainset = torch.load('cifar10-train.t7')
testset = torch.load('cifar10-test.t7')
-- Human-readable names of the ten classes.
-- NOTE(review): presumably ordered to match the numeric labels; confirm
-- whether the stored labels are 0- or 1-based before indexing with them.
classes = {'airplane', 'automobile', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
In [4]:
-- Inspect the structure of the loaded training set.
print(trainset)
Out[4]:
In [5]:
-- Inspect the structure of the loaded test set.
print(testset)
Out[5]:
In [6]:
--- Fill a tensor by calling `f` with the index of every element.
-- `A:apply` is invoked with a callback that ignores the current element
-- value, advances an index counter, and returns `f(i1, i2, ..., in)`.
-- Assumes `apply` visits elements with the last dimension varying fastest;
-- the counter below is advanced in that order.
-- @param A tensor-like object providing `dim()`, `size()` and `apply(fn)`
-- @param f function receiving one integer index per dimension
-- @return whatever `A:apply` returns
function tabulate(A, f)
  -- `unpack` moved to `table.unpack` in Lua 5.2+; support both runtimes.
  local unpack = table.unpack or unpack
  local ndims = A:dim()
  local dim = A:size()

  -- Start one step *before* the first index (1, 1, ..., 0) so that the
  -- first increment inside the callback yields (1, 1, ..., 1).
  local idx = {}
  for i = 1, ndims - 1 do
    idx[i] = 1
  end
  idx[ndims] = 0

  return A:apply(function()
    -- Odometer-style increment with carry. Indices are 1-based, so the
    -- carry must stop at dimension 1 (the original loop ran down to 0 and
    -- would have done arithmetic on the non-existent idx[0]).
    for i = ndims, 1, -1 do
      idx[i] = idx[i] + 1
      if idx[i] <= dim[i] then
        break
      end
      idx[i] = 1
    end
    return f(unpack(idx))
  end)
end
In [17]:
--- k-nearest-neighbour classifier over dataset tables that expose a `data`
-- field (indexable by sample, with `size(1)`) and a `label` field
-- (indexable by sample).
NearestNeighbor = {}

--- Create a classifier instance.
-- @param o optional table to turn into an instance (a fresh table if nil)
-- @return the instance, with this table as its prototype
function NearestNeighbor:new(o)
  -- The original read an undefined global here and silently discarded `o`.
  o = o or {}
  setmetatable(o, self)
  self.__index = self
  return o
end

--- Remember the training set; nearest-neighbour "training" is just storage.
-- The set is stored on the receiver, so separate instances no longer
-- overwrite each other's data through the shared class table.
-- @param trainset dataset table with `data` and `label` fields
function NearestNeighbor:train(trainset)
  self.trainset = trainset or {}
end

--- Predict a label for the leading samples of `testset`.
-- Each test sample is compared against the leading training samples; the
-- label occurring most often among the `k` closest wins. On a tie, the
-- label that reached the winning count first (i.e. via closer neighbours)
-- is kept.
-- @param testset dataset table with a `data` field
-- @param is_l2 true for Euclidean (L2) distance, false/nil for L1
-- @param k number of neighbours that vote (default 1)
-- @param n_test max number of test samples to classify (default 10)
-- @param n_train max number of training samples to search (default 1000)
-- @return array of predicted labels, one per classified test sample
function NearestNeighbor:predict(testset, is_l2, k, n_test, n_train)
  is_l2 = is_l2 or false
  k = k or 1

  -- Use the data handed to train() rather than whatever global happens to
  -- be called `trainset`.
  local train = self.trainset
  if train == nil or train.data == nil then
    error("NearestNeighbor:predict: call train() with a dataset first", 2)
  end

  -- The defaults keep the original quick-experiment limits, but never run
  -- past the end of either dataset.
  local test_count = math.min(n_test or 10, testset.data:size(1))
  local train_count = math.min(n_train or 1000, train.data:size(1))
  -- Cannot collect more votes than there are candidates.
  local voters = math.min(k, train_count)

  if is_l2 then
    print('l2', k)
  else
    print('l1', k)
  end

  local predictions = {}
  for i = 1, test_count do
    -- Convert to float *before* subtracting. If the data are stored in an
    -- unsigned/byte tensor (the original cast the difference to float
    -- afterwards, which suggests so - confirm), subtracting first would
    -- wrap negative differences around.
    local sample = testset.data[i]:float()

    local distances = {}
    for j = 1, train_count do
      local diff = train.data[j]:float() - sample
      local dist
      if is_l2 then
        dist = math.sqrt(torch.pow(diff, 2):sum())
      else
        dist = torch.abs(diff):sum()
      end
      distances[j] = {j, dist}
    end
    table.sort(distances, function(a, b) return a[2] < b[2] end)

    -- Majority vote among the closest `voters` training samples.
    local counter, best, answer = {}, 0, nil
    for rank = 1, voters do
      local label = train.label[distances[rank][1]]
      counter[label] = (counter[label] or 0) + 1
      if best < counter[label] then
        best = counter[label]
        answer = label
      end
    end
    predictions[#predictions + 1] = answer
  end
  return predictions
end
-- Build the classifier and time four prediction runs: L1 and L2 distance,
-- each with k = 1 and k = 5. `nn` and `results1..4` stay global because the
-- accuracy report below (and possibly later notebook cells) read them.
nn = NearestNeighbor:new()
nn:train(trainset)

--- Run one prediction pass and print the CPU time it took.
-- os.clock() values are plain numbers of seconds, so they are subtracted
-- directly; os.difftime is meant for os.time() values.
-- @param label text printed in front of the elapsed time
-- @param is_l2 forwarded to nn:predict
-- @param k forwarded to nn:predict (nil selects its default)
-- @return the predictions returned by nn:predict
local function timed_predict(label, is_l2, k)
  local started = os.clock()
  local predicted = nn:predict(testset, is_l2, k)
  print(label, os.clock() - started)
  return predicted
end

-- Labels name what is actually measured; the previous "Original time:" /
-- "Apply method:" captions did not correspond to these runs.
results1 = timed_predict("L1, k=1 time:", false)
results2 = timed_predict("L2, k=1 time:", true)
results3 = timed_predict("L1, k=5 time:", false, 5)
results4 = timed_predict("L2, k=5 time:", true, 5)
--- Percentage of predictions that match the ground-truth labels.
-- Prediction `i` is compared with `testset.label[i]`; only the first
-- `#results` labels are consulted.
-- @param results array of predicted labels
-- @param testset dataset table with a `label` field indexable by position
-- @return accuracy in the range 0..100 (0 when `results` is empty)
function compare(results, testset)
  local total = #results
  -- Avoid 0/0 (NaN) when there is nothing to score.
  if total == 0 then
    return 0
  end

  local hits = 0
  for i = 1, total do
    if results[i] == testset.label[i] then
      hits = hits + 1
    end
  end
  return hits / total * 100
end
-- Report the accuracy of each of the four prediction runs, in the order
-- they were produced.
local runs = {results1, results2, results3, results4}
for run = 1, 4 do
  local accuracy = compare(runs[run], testset)
  print(string.format('accuracy: %.2f', accuracy))
end
Out[17]:
In [ ]: